{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import cv2\n",
    "import os\n",
    "import scipy.io as sio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skdemo import imshow_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_images_from_folder(folder):\n",
    "    images = []\n",
    "    for filename in os.listdir(folder):\n",
    "        f = os.path.join(folder,filename)\n",
    "        img =cv2.imread(f, cv2.IMREAD_GRAYSCALE)\n",
    "        if img is not None:\n",
    "            images.append(img)\n",
    "    return images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path1 = 'sample3_left_edge' # folder of the frist CT volumn\n",
    "path2 = 'sample3_right_edge'# folder of the second CT volumn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_left = load_images_from_folder(path1)\n",
    "data_right = load_images_from_folder(path2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_left = np.swapaxes(data_left,1, 2)\n",
    "data_right = np.swapaxes(data_right,1, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(1,2, figsize = (10,10))\n",
    "axs[0].imshow(data_left[:,-200:,9], cmap = 'gray')\n",
    "axs[1].imshow(data_right[:,0:200,20], cmap = 'gray')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image0 = data_left[:,-200:,9]\n",
    "image1 = data_right[:,0:200,20]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "#fig, axs = plt.subplots(2,1, figsize = (5,10))\n",
    "ax1 = plt.subplot(211)\n",
    "sns.distplot(image0.ravel(),ax=ax1)\n",
    "ax2 = plt.subplot(212, sharex=ax1)\n",
    "sns.distplot(image1.ravel(),ax=ax2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _match_cumulative_cdf(source, template):\n",
    "    \"\"\"\n",
    "    Return modified source array so that the cumulative density function of\n",
    "    its values matches the cumulative density function of the template.\n",
    "    \"\"\"\n",
    "    src_values, src_unique_indices, src_counts = np.unique(source.ravel(),\n",
    "                                                           return_inverse=True,\n",
    "                                                           return_counts=True)\n",
    "    tmpl_values, tmpl_counts = np.unique(template.ravel(), return_counts=True)\n",
    "\n",
    "    # calculate normalized quantiles for each array\n",
    "    src_quantiles = np.cumsum(src_counts) / source.size\n",
    "    tmpl_quantiles = np.cumsum(tmpl_counts) / template.size\n",
    "\n",
    "    interp_a_values = np.interp(src_quantiles, tmpl_quantiles, tmpl_values)\n",
    "    return interp_a_values[src_unique_indices].reshape(source.shape)\n",
    "\n",
    "\n",
    "\n",
    "def match_histograms(image, reference, multichannel=False):\n",
    "    \"\"\"Adjust an image so that its cumulative histogram matches that of another.\n",
    "\n",
    "    The adjustment is applied separately for each channel.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    image : ndarray\n",
    "        Input image. Can be gray-scale or in color.\n",
    "    reference : ndarray\n",
    "        Image to match histogram of. Must have the same number of channels as\n",
    "        image.\n",
    "    multichannel : bool, optional\n",
    "        Apply the matching separately for each channel.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    matched : ndarray\n",
    "        Transformed input image.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        Thrown when the number of channels in the input image and the reference\n",
    "        differ.\n",
    "\n",
    "    References\n",
    "    ----------\n",
    "    .. [1] http://paulbourke.net/miscellaneous/equalisation/\n",
    "\n",
    "    \"\"\"\n",
    "    shape = image.shape\n",
    "    image_dtype = image.dtype\n",
    "\n",
    "    if image.ndim != reference.ndim:\n",
    "        raise ValueError('Image and reference must have the same number of channels.')\n",
    "\n",
    "    if multichannel:\n",
    "        if image.shape[-1] != reference.shape[-1]:\n",
    "            raise ValueError('Number of channels in the input image and reference '\n",
    "                             'image must match!')\n",
    "\n",
    "        matched = np.empty(image.shape, dtype=image.dtype)\n",
    "        for channel in range(image.shape[-1]):\n",
    "            matched_channel = _match_cumulative_cdf(image[..., channel], reference[..., channel])\n",
    "            matched[..., channel] = matched_channel\n",
    "    else:\n",
    "        matched = _match_cumulative_cdf(image, reference)\n",
    "\n",
    "    return matched"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "image = image0\n",
    "reference = image1\n",
    "matched = match_histograms(image, reference, multichannel=False)\n",
    "\n",
    "fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3, figsize=(8, 3),\n",
    "                                    sharex=True, sharey=True)\n",
    "for aa in (ax1, ax2, ax3):\n",
    "    aa.set_axis_off()\n",
    "\n",
    "ax1.imshow(image)\n",
    "ax1.set_title('Source')\n",
    "ax2.imshow(reference)\n",
    "ax2.set_title('Reference')\n",
    "ax3.imshow(matched)\n",
    "ax3.set_title('Matched')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#fig, axs = plt.subplots(2,1, figsize = (5,10))\n",
    "ax1 = plt.subplot(211)\n",
    "sns.distplot(reference.ravel(),ax=ax1)\n",
    "ax2 = plt.subplot(212, sharex=ax1)\n",
    "sns.distplot(matched.ravel(),ax=ax2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image0 = matched.astype(np.uint8)\n",
    "image0 = np.divide(image0,2)\n",
    "image1 = reference.astype(np.uint8)\n",
    "image1 = np.divide(image1,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#fig, axs = plt.subplots(2,1, figsize = (5,10))\n",
    "ax1 = plt.subplot(121)\n",
    "plt.imshow(image0)\n",
    "ax2 = plt.subplot(122, sharex=ax1)\n",
    "plt.imshow(image1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skimage.feature import ORB, match_descriptors\n",
    "\n",
    "orb = ORB(n_keypoints=1000, fast_threshold=0.05)\n",
    "\n",
    "orb.detect_and_extract(image0)\n",
    "keypoints1 = orb.keypoints\n",
    "descriptors1 = orb.descriptors\n",
    "\n",
    "orb.detect_and_extract(image1)\n",
    "keypoints2 = orb.keypoints\n",
    "descriptors2 = orb.descriptors\n",
    "\n",
    "matches12 = match_descriptors(descriptors1, descriptors2, cross_check=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skimage.feature import plot_matches\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
    "plot_matches(ax, image0, image1, keypoints1, keypoints2, matches12)\n",
    "ax.axis('off');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skimage.transform import ProjectiveTransform\n",
    "from skimage.measure import ransac\n",
    "from skimage.feature import plot_matches\n",
    "\n",
    "# Select keypoints from the source (image to be registered)\n",
    "# and target (reference image)\n",
    "src = keypoints2[matches12[:, 1]][:, ::-1]\n",
    "dst = keypoints1[matches12[:, 0]][:, ::-1]\n",
    "\n",
    "model_robust, inliers = ransac((src, dst), ProjectiveTransform,\n",
    "                               min_samples=4, residual_threshold=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=(15, 15))\n",
    "plot_matches(ax, image0, image1, keypoints1, keypoints2, matches12[inliers])\n",
    "ax.axis('off');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skimage.transform import SimilarityTransform\n",
    "\n",
    "r, c = image1.shape[:2]\n",
    "\n",
    "# Note that transformations take coordinates in (x, y) format,\n",
    "# not (row, column), in order to be consistent with most literature\n",
    "corners = np.array([[0, 0],\n",
    "                    [0, r],\n",
    "                    [c, 0],\n",
    "                    [c, r]])\n",
    "\n",
    "# Warp the image corners to their new positions\n",
    "warped_corners = model_robust(corners)\n",
    "\n",
    "# Find the extents of both the reference image and the warped\n",
    "# target image\n",
    "all_corners = np.vstack((warped_corners, corners))\n",
    "\n",
    "corner_min = np.min(all_corners, axis=0)\n",
    "corner_max = np.max(all_corners, axis=0)\n",
    "\n",
    "output_shape = (corner_max - corner_min)\n",
    "output_shape = np.ceil(output_shape[::-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skimage.color import gray2rgb\n",
    "from skimage.exposure import rescale_intensity\n",
    "from skimage.transform import warp\n",
    "\n",
    "offset = SimilarityTransform(translation=-corner_min)\n",
    "\n",
    "image0_ = warp(image0, offset.inverse,\n",
    "               output_shape=output_shape, cval=-1)\n",
    "\n",
    "image1_ = warp(image1, (model_robust + offset).inverse,\n",
    "               output_shape=output_shape, cval=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "imshow_all(image0_, image1_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.shape(image0_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_alpha(image, background=-1):\n",
    "    \"\"\"Add an alpha layer to the image.\n",
    "    \n",
    "    The alpha layer is set to 1 for foreground and 0 for background.\n",
    "    \"\"\"\n",
    "    return np.dstack((gray2rgb(image), (image != background)))\n",
    "\n",
    "image0_alpha = add_alpha(image0_)\n",
    "image1_alpha = add_alpha(image1_)\n",
    "\n",
    "merged = (image0_alpha + image1_alpha)\n",
    "alpha = merged[..., 3]\n",
    "\n",
    "# The summed alpha layers give us an indication of how many\n",
    "# images were combined to make up each pixel.  Divide by the\n",
    "# number of images to get an average.\n",
    "merged /= np.maximum(alpha, 1)[..., np.newaxis]\n",
    "merged = merged[..., :3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_alpha(image, background=-1):\n",
    "    \"\"\"Add an alpha layer to the image.\n",
    "    \n",
    "    The alpha layer is set to 1 for foreground and 0 for background.\n",
    "    \"\"\"\n",
    "    return np.dstack((image, (image != background)))\n",
    "\n",
    "image0_alpha = add_alpha(image0_)\n",
    "image1_alpha = add_alpha(image1_)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.shape(image0_alpha)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "merged = (image0_alpha + image1_alpha)\n",
    "merged = (image1_alpha[:,:,0]+image0_alpha[:,:,0])/(image1_alpha[:,:,1]+image0_alpha[:,:,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(merged)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
